import torch.nn.functional as F
import numpy as np
import torch
from sklearn.cluster import KMeans
import os


# %%
class ClusteringTask(torch.nn.Module):
    """Auxiliary task: predict K-Means pseudo-labels from embeddings.

    For every dataset in ``data.datasets`` the feature matrix ``dataset.data.x``
    is clustered with K-Means once. The labels are cached on disk as
    ``<cache_dir>/<dataset.name>_cluster_<num_clusters>.npy`` and reused when
    that file already exists. A single linear head maps embeddings to
    ``num_clusters`` logits and is trained with cross-entropy against those
    pseudo-labels.
    """

    # Class-level defaults; both can be overridden per instance in __init__.
    num_clusters = 10
    cache_dir = './saved/'

    def __init__(self, data, embedding_size, device, num_clusters=None, cache_dir=None):
        """Build pseudo-labels for every dataset and the linear prediction head.

        Args:
            data: Object exposing ``datasets``; each entry needs a ``name`` and a
                ``data.x`` feature matrix (array-like or torch tensor).
            embedding_size: Input width of the linear prediction head.
            device: Torch device for the pseudo-label tensors and the head.
            num_clusters: Number of K-Means clusters. ``None`` keeps the class
                default (``ClusteringTask.num_clusters``).
            cache_dir: Directory for cached cluster assignments. ``None`` keeps
                the class default (``'./saved/'``). It is created on demand
                when new assignments are saved.
        """
        super().__init__()
        self.name = 'clustering'
        self.data = data
        self.dataset_names = [dataset.name for dataset in self.data.datasets]
        self.device = device
        if num_clusters is not None:
            self.num_clusters = num_clusters
        if cache_dir is not None:
            self.cache_dir = cache_dir
        # One LongTensor of cluster ids per dataset, in the same order as dataset_names.
        self.pseudo_labels = []
        self.create_pseudo_labels()
        self.predictor = torch.nn.Linear(embedding_size, self.num_clusters).to(self.device)

    def get_loss(self, embeddings, dataset_name):
        """Return the cross-entropy between predicted logits and the pseudo-labels.

        Args:
            embeddings: Embeddings for ``dataset_name``. Presumably shape
                ``(num_rows, embedding_size)``, with rows aligned to that
                dataset's ``data.x`` (TODO: confirm against callers).
            dataset_name: Name of one of the datasets given at construction.

        Returns:
            Scalar loss tensor (mean reduction).

        Raises:
            ValueError: If ``dataset_name`` is not one of the known datasets.
        """
        try:
            index = self.dataset_names.index(dataset_name)
        except ValueError:
            raise ValueError(
                f'unknown dataset {dataset_name!r}; expected one of {self.dataset_names}'
            ) from None
        logits = self.predictor(embeddings)
        # Equivalent to nll_loss(log_softmax(logits, dim=1), target).
        return F.cross_entropy(logits, self.pseudo_labels[index])

    def create_pseudo_labels(self):
        """Append one LongTensor of cluster ids per dataset to ``self.pseudo_labels``."""
        for dataset in self.data.datasets:
            labels = self._load_or_cluster(dataset)
            self.pseudo_labels.append(
                torch.as_tensor(labels, dtype=torch.long, device=self.device)
            )

    def _load_or_cluster(self, dataset):
        """Return cached K-Means labels for ``dataset``; compute and cache them if absent."""
        cluster_file = os.path.join(
            self.cache_dir, f'{dataset.name}_cluster_{self.num_clusters}.npy'
        )
        if os.path.exists(cluster_file):
            return np.load(cluster_file)

        print('Performing clustering with K-Means on ' + dataset.name)
        features = dataset.data.x
        if torch.is_tensor(features):
            # sklearn needs a host array; detach/cpu also handles GPU or grad-tracking tensors.
            features = features.detach().cpu().numpy()
        labels = KMeans(n_clusters=self.num_clusters, random_state=0).fit(features).labels_
        # Create the cache directory before saving: np.save does not create parent dirs.
        if self.cache_dir:
            os.makedirs(self.cache_dir, exist_ok=True)
        np.save(cluster_file, labels)
        return labels
